package org.sunxin.ch04.servlet;

import java.io.*;
import java.math.BigDecimal;
import java.sql.*;
import javax.servlet.*;
import javax.servlet.http.*;

public class TradeServlet extends HttpServlet
{
    private String url;
    private String user;
    private String password;
    
    public void init() throws ServletException
    {
    	/**
    	 * 通过 servlet上下文获得数据库连接参数
    	 */
        ServletContext sc=getServletContext();
        String driverClass=sc.getInitParameter("driverClass");
        url=sc.getInitParameter("url");
        user=sc.getInitParameter("user");
        password=sc.getInitParameter("password");
        try
        {
            Class.forName(driverClass);
        }
        catch(ClassNotFoundException ce)
        {
            throw new ServletException("加载数据库驱动失败！");
        }
    }
    
    public void doGet(HttpServletRequest req, HttpServletResponse resp)
               throws ServletException,IOException
    {
        Connection conn=null;
        Statement stmt=null;
        PreparedStatement pstmt=null;
        ResultSet rs=null;
        
        resp.setContentType("text/html;charset=gb2312");
        PrintWriter out=resp.getWriter();
        
        req.setCharacterEncoding("gb2312");
        
        String userid=req.getParameter("userid");
        String quantity=req.getParameter("quantity");
        
        if(null==userid || userid.equals("") || 
           null==quantity || quantity.equals(""))
        {
            
            out.println("错误的请求参数");
            out.close();
        }
        else
        {
            try
            {
                conn=DriverManager.getConnection(url,user,password);
                
                conn.setAutoCommit(false);
                conn.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ);
                
                stmt=conn.createStatement();
                rs=stmt.executeQuery("select price,amount from bookinfo where id=3");
                rs.next();
                float price=rs.getFloat(1);
                int amount=rs.getInt(2);
                
                int num=Integer.parseInt(quantity);
                if(amount>=num)
                {
                    pstmt=conn.prepareStatement("update bookinfo set amount = ? where id = 3");
                    pstmt.setInt(1,amount-num);
                    pstmt.executeUpdate();
                }
                else
                {
                    out.println("您所购买的图书库存数量不足。");
                    out.close();
                    return;
                }
                pstmt=conn.prepareStatement("select balance from account where userid = ?");
                pstmt.setString(1,userid);
                rs=pstmt.executeQuery();
                
                rs.next();
                float balance=rs.getFloat(1);
                
                float totalPrice=price*num;
                
                if(balance>=totalPrice)
                {
                    pstmt=conn.prepareStatement("update account set balance = ? where userid = ?");
                    pstmt.setFloat(1,balance-totalPrice);
                    pstmt.setString(2,userid);
                    pstmt.executeUpdate();
                }
                else
                {
                    conn.rollback();
                    out.println("您的余额不足。");
                    out.close();
                    return;
                }
                conn.commit();
                out.println("交易成功!");
                out.close();
            }
            catch(SQLException se)
            {
                if(conn!=null)
                {
                    try
                    {
                        conn.rollback();
                    }
                    catch(SQLException sex)
                    {
                        sex.printStackTrace();
                    }
                }   
                se.printStackTrace();
            }
            finally
            {
                if(rs!=null)
                {
                    try
                    {
                        rs.close();
                    }
                    catch(SQLException se)
                    {
                        se.printStackTrace();
                    }
                    rs=null;
                }
                if(stmt!=null)
                {
                    try
                    {
                        stmt.close();
                    }
                    catch(SQLException se)
                    {
                        se.printStackTrace();
                    }
                    stmt=null;
                }
                if(pstmt!=null)
                {
                    try
                    {
                        pstmt.close();
                    }
                    catch(SQLException se)
                    {
                        se.printStackTrace();
                    }
                    pstmt=null;
                }
                if(conn!=null)
                {
                    try
                    {
                        conn.close();
                    }
                    catch(SQLException se)
                    {
                        se.printStackTrace();
                    }
                    conn=null;
                }
            }
        }
    }
    
    public void doPost(HttpServletRequest req, HttpServletResponse resp)
               throws ServletException,IOException
    {
        doGet(req,resp);
    }
}